// Copyright 2015 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/profiler/native_stack_sampler.h"

#include <objbase.h>
#include <stddef.h>
#include <windows.h>
#include <winternl.h>

#include <cstdlib>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/profiler/win32_stack_frame_unwinder.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
#include "base/win/pe_image.h"
#include "base/win/scoped_handle.h"

namespace base {

// Stack recording functions --------------------------------------------------

namespace {

    // Minimal local declaration of the thread environment block (TEB). Only the
    // leading NT_TIB member is declared, because the only field this file reads
    // through a TEB pointer is Tib.StackBase.
    struct TEB {
        NT_TIB Tib;
        // Rest of struct is ignored.
    };

    // Looks up the thread environment block for |thread_handle| by calling
    // NtQueryInformationThread, which is resolved from ntdll.dll at runtime.
    // Returns nullptr if the function cannot be resolved or if the query reports a
    // non-zero status, so callers must check the result before dereferencing it.
    const TEB* GetThreadEnvironmentBlock(HANDLE thread_handle)
    {
        // Local declarations of the internal types that NtQueryInformationThread
        // works with.
        enum THREAD_INFORMATION_CLASS { ThreadBasicInformation };

        struct CLIENT_ID {
            HANDLE UniqueProcess;
            HANDLE UniqueThread;
        };

        struct THREAD_BASIC_INFORMATION {
            NTSTATUS ExitStatus;
            TEB* Teb;
            CLIENT_ID ClientId;
            KAFFINITY AffinityMask;
            LONG Priority;
            LONG BasePriority;
        };

        using QueryThreadFunction = NTSTATUS(WINAPI*)(HANDLE, THREAD_INFORMATION_CLASS, PVOID, ULONG, PULONG);

        const HMODULE ntdll = ::GetModuleHandle(L"ntdll.dll");
        const QueryThreadFunction query_thread = reinterpret_cast<QueryThreadFunction>(
            ::GetProcAddress(ntdll, "NtQueryInformationThread"));
        if (query_thread == nullptr)
            return nullptr;

        THREAD_BASIC_INFORMATION info = { 0 };
        const NTSTATUS status = query_thread(thread_handle, ThreadBasicInformation,
            &info, sizeof(info), nullptr);

        // A zero status is the only result treated as success.
        return status == 0 ? info.Teb : nullptr;
    }

#if defined(_WIN64)
    // Redirects |*pointer| into the stack copy when it refers to the original
    // stack. The original stack occupies the half-open address range
    // [bottom, top); a value in that range is replaced by the location at the
    // same offset from |stack_copy|. Any other value is left untouched.
    void RewritePointerIfInOriginalStack(uintptr_t top, uintptr_t bottom,
        void* stack_copy, const void** pointer)
    {
        const uintptr_t address = reinterpret_cast<uintptr_t>(*pointer);
        if (address < bottom || address >= top)
            return;

        const uintptr_t offset = address - bottom;
        unsigned char* const copy_base = static_cast<unsigned char*>(stack_copy);
        *pointer = copy_base + offset;
    }
#endif

    // Retargets everything that looks like a pointer into the original stack,
    // whose address range is [bottom, top), so that it points at the matching
    // offset inside |stack_copy|. Both the non-volatile registers held in
    // |context| and every pointer-sized slot of the copied stack are processed.
    // Frames that use dynamic stack allocation keep a pointer to the start of the
    // dynamic area on the stack and/or in a non-volatile register, so unwinding
    // the copy requires those values to refer to the copy.
    //
    // Rewriting eagerly is harmless for unwinding: apart from such pointers the
    // unwinder only depends on return addresses, which should not point into
    // stack memory. No pointer is missed because the ABI guarantees that stacks
    // are sizeof(void*) aligned.
    //
    // Note: the original stack must not be read here since it may have changed
    // or been deallocated by now. That is why |top| and |bottom| are plain
    // uintptr_t values rather than pointers.
    //
    // This function does nothing unless _WIN64 is defined.
    void RewritePointersToStackMemory(uintptr_t top, uintptr_t bottom,
        CONTEXT* context, void* stack_copy)
    {
#if defined(_WIN64)
        DWORD64 CONTEXT::*const kNonvolatileRegisters[] = {
            &CONTEXT::R12,
            &CONTEXT::R13,
            &CONTEXT::R14,
            &CONTEXT::R15,
            &CONTEXT::Rdi,
            &CONTEXT::Rsi,
            &CONTEXT::Rbx,
            &CONTEXT::Rbp,
            &CONTEXT::Rsp
        };

        // First pass: the register values stored in the context.
        for (const auto register_member : kNonvolatileRegisters) {
            DWORD64& register_value = context->*register_member;
            RewritePointerIfInOriginalStack(top, bottom, stack_copy,
                reinterpret_cast<const void**>(&register_value));
        }

        // Second pass: each pointer-sized slot of the copied stack.
        const void** slot = static_cast<const void**>(stack_copy);
        const void** const slots_end = reinterpret_cast<const void**>(
            static_cast<char*>(stack_copy) + (top - bottom));
        while (slot < slots_end) {
            RewritePointerIfInOriginalStack(top, bottom, stack_copy, slot);
            ++slot;
        }
#endif
    }

    // Movable, non-copyable record of a single stack frame: the instruction
    // pointer of the frame and a handle to the module associated with it.
    struct RecordedFrame {
        // Starts with a null instruction pointer so that a default-constructed
        // frame never exposes an indeterminate value.
        RecordedFrame()
            : instruction_pointer(nullptr)
        {
        }

        // Copies the instruction pointer and takes over |other|'s module handle.
        RecordedFrame(RecordedFrame&& other)
            : instruction_pointer(other.instruction_pointer)
            , module(std::move(other.module))
        {
        }

        // Copies the instruction pointer and takes over |other|'s module handle.
        RecordedFrame& operator=(RecordedFrame&& other)
        {
            instruction_pointer = other.instruction_pointer;
            module = std::move(other.module);
            return *this;
        }

        // Instruction pointer recorded for the frame.
        const void* instruction_pointer;
        // Module handle recorded together with |instruction_pointer|.
        ScopedModuleHandle module;

    private:
        DISALLOW_COPY_AND_ASSIGN(RecordedFrame);
    };

    // Unwinds the stack described by |context| one frame at a time, starting at
    // the current frame, and appends the instruction pointer and module of each
    // successfully unwound frame to |stack|. The walk ends when the instruction
    // pointer becomes zero or when a frame cannot be unwound; the frame that
    // failed to unwind is not recorded. |stack| is expected to be empty on entry.
    //
    // This function does nothing unless _WIN64 is defined.
    void RecordStack(CONTEXT* context, std::vector<RecordedFrame>* stack)
    {
#ifdef _WIN64
        DCHECK(stack->empty());

        // Approximately 99.9% of recorded stacks are 128 frames or fewer, so
        // reserving that much up front avoids repeated allocations for most
        // stacks.
        stack->reserve(128);

        Win32StackFrameUnwinder frame_unwinder;
        while (context->Rip != 0) {
            // Capture the instruction pointer first: unwinding updates |context|.
            RecordedFrame frame;
            frame.instruction_pointer = reinterpret_cast<const void*>(context->Rip);
            if (!frame_unwinder.TryUnwind(context, &frame.module))
                break;
            stack->push_back(std::move(frame));
        }
#endif
    }

    // Gets the unique build ID for a module. Windows build IDs are created by a
    // concatenation of a GUID and AGE fields found in the headers of a module. The
    // GUID is stored in the first 16 bytes and the AGE is stored in the last 4
    // bytes. Returns the empty string if the function fails to get the build ID,
    // which covers both a module without retrievable debug information and a
    // GUID that cannot be converted to its string form.
    //
    // Example:
    // dumpbin chrome.exe /headers | find "Format:"
    //   ... Format: RSDS, {16B2A428-1DED-442E-9A36-FCE8CBD29726}, 10, ...
    //
    // The resulting buildID string of this instance of chrome.exe is
    // "16B2A4281DED442E9A36FCE8CBD2972610".
    //
    // Note that the AGE field is encoded in decimal, not hex.
    std::string GetBuildIDForModule(HMODULE module_handle)
    {
        GUID guid;
        DWORD age;
        // |guid| and |age| are only meaningful when the lookup succeeds; bail out
        // instead of formatting uninitialized values into a bogus build ID.
        if (!win::PEImage(module_handle).GetDebugId(&guid, &age))
            return std::string();

        // Length of the "{xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}" form including the
        // terminating null character.
        const int kGUIDSize = 39;
        std::wstring build_id;
        int result = ::StringFromGUID2(guid, WriteInto(&build_id, kGUIDSize), kGUIDSize);
        if (result != kGUIDSize)
            return std::string();

        // Strip the braces and dashes, then append the age in decimal.
        RemoveChars(build_id, L"{}-", &build_id);
        build_id += StringPrintf(L"%d", age);
        return WideToUTF8(build_id);
    }

    // ScopedDisablePriorityBoost -------------------------------------------------

    // Disables priority boost on a thread for the lifetime of the object. The
    // boost state that was in effect at construction is reinstated on
    // destruction, provided it could be read in the first place.
    class ScopedDisablePriorityBoost {
    public:
        // |thread_handle| is stored and used again by the destructor. The
        // constructor is explicit so that a HANDLE never converts implicitly
        // into an object that changes thread state.
        explicit ScopedDisablePriorityBoost(HANDLE thread_handle);
        ~ScopedDisablePriorityBoost();

    private:
        // Thread whose priority boost setting is being managed.
        HANDLE thread_handle_;
        // Result of querying the boost state at construction; nothing is changed
        // or restored when the query failed.
        BOOL got_previous_boost_state_;
        // Boost state reported at construction, reapplied on destruction.
        BOOL boost_state_was_disabled_;

        DISALLOW_COPY_AND_ASSIGN(ScopedDisablePriorityBoost);
    };

    // Reads the current priority boost state of |thread_handle| and, only if
    // that read succeeds, disables priority boost for the thread.
    ScopedDisablePriorityBoost::ScopedDisablePriorityBoost(HANDLE thread_handle)
        : thread_handle_(thread_handle)
        , got_previous_boost_state_(FALSE)
        , boost_state_was_disabled_(FALSE)
    {
        got_previous_boost_state_ = ::GetThreadPriorityBoost(thread_handle_, &boost_state_was_disabled_);
        if (!got_previous_boost_state_)
            return;

        // Passing TRUE is what turns priority boost off, despite how it reads.
        ::SetThreadPriorityBoost(thread_handle_, TRUE);
    }

    // Reapplies the boost state captured at construction. Nothing is done when
    // that state could not be read, since nothing was changed in that case.
    ScopedDisablePriorityBoost::~ScopedDisablePriorityBoost()
    {
        if (!got_previous_boost_state_)
            return;

        ::SetThreadPriorityBoost(thread_handle_, boost_state_was_disabled_);
    }

    // ScopedSuspendThread --------------------------------------------------------

    // Suspends a thread for the lifetime of the object. The thread is resumed on
    // destruction only if the suspension succeeded.
    class ScopedSuspendThread {
    public:
        // Attempts to suspend the thread identified by |thread_handle|. The
        // constructor is explicit so that a HANDLE never converts implicitly
        // into an object that suspends a thread.
        explicit ScopedSuspendThread(HANDLE thread_handle);
        ~ScopedSuspendThread();

        // Returns true if the thread was suspended by the constructor.
        bool was_successful() const { return was_successful_; }

    private:
        // Thread that was asked to suspend.
        HANDLE thread_handle_;
        // Whether the suspend request made at construction succeeded.
        bool was_successful_;

        DISALLOW_COPY_AND_ASSIGN(ScopedSuspendThread);
    };

    // Requests suspension of |thread_handle| and records whether the request
    // succeeded. SuspendThread returning (DWORD)-1 is treated as failure.
    ScopedSuspendThread::ScopedSuspendThread(HANDLE thread_handle)
        : thread_handle_(thread_handle)
        , was_successful_(false)
    {
        const DWORD suspend_result = ::SuspendThread(thread_handle_);
        was_successful_ = suspend_result != static_cast<DWORD>(-1);
    }

    // Resumes the thread if the constructor managed to suspend it. A failure to
    // resume triggers a CHECK.
    ScopedSuspendThread::~ScopedSuspendThread()
    {
        if (was_successful_) {
            // Priority boost is switched off around the resume so that the thread
            // does not receive the boost it would otherwise get on resume. The
            // intent is to avoid altering the dynamics of the executing
            // application any more than suspending and resuming already does.
            //
            // This is racy: a boost that the thread would legitimately have
            // received can be lost if the thread was waiting on other wait
            // conditions when it was suspended and those conditions are satisfied
            // before priority boost is reenabled. The measured length of this
            // window is ~100us, so this should occur fairly rarely.
            ScopedDisablePriorityBoost disable_priority_boost(thread_handle_);
            const DWORD resume_result = ::ResumeThread(thread_handle_);
            const bool resume_thread_succeeded = resume_result != static_cast<DWORD>(-1);
            CHECK(resume_thread_succeeded) << "ResumeThread failed: " << GetLastError();
        }
    }

    // Tests whether |stack_pointer| points to a location in the guard page.
    //
    // IMPORTANT NOTE: This function is invoked while the target thread is
    // suspended so it must not do any allocation from the default heap, including
    // indirectly via use of DCHECK/CHECK or other logging statements. Otherwise
    // this code can deadlock on heap locks in the default heap acquired by the
    // target thread before it was suspended.
    bool PointsToGuardPage(uintptr_t stack_pointer)
    {
        MEMORY_BASIC_INFORMATION memory_info;
        SIZE_T result = ::VirtualQuery(reinterpret_cast<LPCVOID>(stack_pointer),
            &memory_info,
            sizeof(memory_info));
        return result != 0 && (memory_info.Protect & PAGE_GUARD);
    }

    // Captures the stack of the thread identified by |thread_handle|. While the
    // thread is suspended, its context is read and the memory between its stack
    // pointer and |base_address| is copied into |stack_copy_buffer|. After the
    // thread has been resumed, pointers into the original stack are retargeted
    // at the copy and the frames are recorded into |stack|.
    //
    // |stack| is left empty if the thread cannot be suspended, its context
    // cannot be read, the stack does not fit in |stack_copy_buffer_size| bytes,
    // or the stack pointer lies in a guard page. |test_delegate| may be null; if
    // set, it is notified after the thread is resumed and before the stack walk.
    //
    // IMPORTANT NOTE: No allocations from the default heap may occur in the
    // ScopedSuspendThread scope, including indirectly via use of DCHECK/CHECK or
    // other logging statements. Otherwise this code can deadlock on heap locks in
    // the default heap acquired by the target thread before it was suspended.
    void SuspendThreadAndRecordStack(
        HANDLE thread_handle,
        const void* base_address,
        void* stack_copy_buffer,
        size_t stack_copy_buffer_size,
        std::vector<RecordedFrame>* stack,
        NativeStackSamplerTestDelegate* test_delegate)
    {
        DCHECK(stack->empty());

        CONTEXT thread_context = { 0 };
        thread_context.ContextFlags = CONTEXT_FULL;

        // The bounds are held as integers because they are still needed after the
        // thread resumes, when its stack memory is no longer safe to dereference.
        const uintptr_t top = reinterpret_cast<uintptr_t>(base_address);
        uintptr_t bottom = 0u;

        {
            ScopedSuspendThread suspend_thread(thread_handle);

            if (!suspend_thread.was_successful()
                || !::GetThreadContext(thread_handle, &thread_context)) {
                return;
            }

#if defined(_WIN64)
            bottom = thread_context.Rsp;
#else
            bottom = thread_context.Esp;
#endif

            // Give up if the stack is too large for the buffer. Also give up if
            // the stack pointer is in the guard page: dereferencing a pointer in
            // the guard page from a thread that doesn't own the stack results in
            // a STATUS_GUARD_PAGE_VIOLATION exception and a crash. This occurs
            // very rarely, but reliably over the population.
            const uintptr_t stack_size = top - bottom;
            if (stack_size > stack_copy_buffer_size || PointsToGuardPage(bottom))
                return;

            std::memcpy(stack_copy_buffer, reinterpret_cast<const void*>(bottom),
                stack_size);
        }

        if (test_delegate)
            test_delegate->OnPreStackWalk();

        RewritePointersToStackMemory(top, bottom, &thread_context, stack_copy_buffer);

        RecordStack(&thread_context, stack);
    }

    // NativeStackSamplerWin ------------------------------------------------------

    // NativeStackSampler implementation that samples the stack of the thread
    // identified by the handle passed at construction. Each sample is taken by
    // copying the thread's stack into a buffer owned by this object and walking
    // the copy.
    class NativeStackSamplerWin : public NativeStackSampler {
    public:
        // Takes ownership of |thread_handle|. |test_delegate| is stored as a raw
        // pointer and forwarded to the stack recording code on every sample.
        NativeStackSamplerWin(win::ScopedHandle thread_handle,
            NativeStackSamplerTestDelegate* test_delegate);
        ~NativeStackSamplerWin() override;

        // StackSamplingProfiler::NativeStackSampler:
        void ProfileRecordingStarting(
            std::vector<StackSamplingProfiler::Module>* modules) override;
        void RecordStackSample(StackSamplingProfiler::Sample* sample) override;
        void ProfileRecordingStopped() override;

    private:
        enum {
            // Intended to hold the largest stack used by Chrome. The default Win32
            // reserved stack size is 1 MB and Chrome Windows threads currently always
            // use the default, but this allows for expansion if it occurs. The size
            // beyond the actual stack size consists of unallocated virtual memory pages
            // so carries little cost (just a bit of wasted address space).
            kStackCopyBufferSize = 2 * 1024 * 1024
        };

        // Attempts to query the module filename, base address, and id for
        // |module_handle|, and store them in |module|. Returns true if it succeeded.
        static bool GetModuleForHandle(HMODULE module_handle,
            StackSamplingProfiler::Module* module);

        // Gets the index for the Module corresponding to |module_handle| in
        // |modules|, adding it if it's not already present. Returns
        // StackSamplingProfiler::Frame::kUnknownModuleIndex if no Module can be
        // determined for |module_handle|.
        size_t GetModuleIndex(HMODULE module_handle,
            std::vector<StackSamplingProfiler::Module>* modules);

        // Copies the information represented by |stack| into |sample| and |modules|.
        void CopyToSample(const std::vector<RecordedFrame>& stack,
            StackSamplingProfiler::Sample* sample,
            std::vector<StackSamplingProfiler::Module>* modules);

        // Handle of the thread being sampled. Declared first because the stack
        // base address below is initialized from it.
        win::ScopedHandle thread_handle_;

        // Passed along to the stack recording code for every sample.
        NativeStackSamplerTestDelegate* const test_delegate_;

        // The stack base address corresponding to |thread_handle_|.
        const void* const thread_stack_base_address_;

        // Buffer to use for copies of the stack. We use the same buffer for all the
        // samples to avoid the overhead of multiple allocations and frees.
        const std::unique_ptr<unsigned char[]> stack_copy_buffer_;

        // Weak. Points to the modules associated with the profile being recorded
        // between ProfileRecordingStarting() and ProfileRecordingStopped().
        std::vector<StackSamplingProfiler::Module>* current_modules_;

        // Maps a module handle to the corresponding Module's index within
        // current_modules_.
        std::map<HMODULE, size_t> profile_module_index_;

        DISALLOW_COPY_AND_ASSIGN(NativeStackSamplerWin);
    };

    NativeStackSamplerWin::NativeStackSamplerWin(
        win::ScopedHandle thread_handle,
        NativeStackSamplerTestDelegate* test_delegate)
        : thread_handle_(thread_handle.Take())
        , test_delegate_(test_delegate)
        , thread_stack_base_address_(
              GetThreadEnvironmentBlock(thread_handle_.Get())->Tib.StackBase)
        , stack_copy_buffer_(new unsigned char[kStackCopyBufferSize])
    {
    }

    // Nothing to release by hand: the members clean up after themselves.
    NativeStackSamplerWin::~NativeStackSamplerWin() = default;

    // Starts a new profile: forgets the module indices cached for the previous
    // profile and remembers |modules| as the list that samples will refer to.
    void NativeStackSamplerWin::ProfileRecordingStarting(
        std::vector<StackSamplingProfiler::Module>* modules)
    {
        profile_module_index_.clear();
        current_modules_ = modules;
    }

    // Records the current stack of the sampled thread into |sample|, adding any
    // newly seen modules to the module list of the profile being recorded.
    // Returns without touching |sample| if there is no stack copy buffer.
    void NativeStackSamplerWin::RecordStackSample(
        StackSamplingProfiler::Sample* sample)
    {
        DCHECK(current_modules_);

        unsigned char* const copy_buffer = stack_copy_buffer_.get();
        if (copy_buffer == nullptr)
            return;

        std::vector<RecordedFrame> frames;
        SuspendThreadAndRecordStack(thread_handle_.Get(), thread_stack_base_address_,
            copy_buffer, kStackCopyBufferSize, &frames, test_delegate_);
        CopyToSample(frames, sample, current_modules_);
    }

    // Ends the current profile by dropping the weak pointer to its module list.
    // The cached module indices are left alone here; they are cleared by the
    // next call to ProfileRecordingStarting().
    void NativeStackSamplerWin::ProfileRecordingStopped()
    {
        current_modules_ = nullptr;
    }

    // static
    // Fills |module| with the file name, base address and build id of
    // |module_handle|. Returns false if the file name cannot be obtained or the
    // build id is empty; |module| may have been partially written in that case.
    bool NativeStackSamplerWin::GetModuleForHandle(
        HMODULE module_handle,
        StackSamplingProfiler::Module* module)
    {
        wchar_t path[MAX_PATH];
        const DWORD path_length = GetModuleFileName(module_handle, path, arraysize(path));
        if (path_length == 0)
            return false;

        module->filename = base::FilePath(path);

        // The module handle value doubles as the base address of the module.
        module->base_address = reinterpret_cast<uintptr_t>(module_handle);

        module->id = GetBuildIDForModule(module_handle);
        return !module->id.empty();
    }

    // Returns the position of the Module for |module_handle| within |modules|.
    // A module seen for the first time is appended to |modules| and its index is
    // cached. Returns StackSamplingProfiler::Frame::kUnknownModuleIndex for a
    // null handle or when the module's details cannot be obtained; such failures
    // are not cached.
    size_t NativeStackSamplerWin::GetModuleIndex(
        HMODULE module_handle,
        std::vector<StackSamplingProfiler::Module>* modules)
    {
        if (!module_handle)
            return StackSamplingProfiler::Frame::kUnknownModuleIndex;

        const auto cached = profile_module_index_.find(module_handle);
        if (cached != profile_module_index_.end())
            return cached->second;

        StackSamplingProfiler::Module module;
        if (!GetModuleForHandle(module_handle, &module))
            return StackSamplingProfiler::Frame::kUnknownModuleIndex;

        modules->push_back(module);
        const size_t new_index = modules->size() - 1;
        profile_module_index_.insert(std::make_pair(module_handle, new_index));
        return new_index;
    }

    void NativeStackSamplerWin::CopyToSample(
        const std::vector<RecordedFrame>& stack,
        StackSamplingProfiler::Sample* sample,
        std::vector<StackSamplingProfiler::Module>* modules)
    {
        sample->clear();
        sample->reserve(stack.size());

        for (const RecordedFrame& frame : stack) {
            sample->push_back(StackSamplingProfiler::Frame(
                reinterpret_cast<uintptr_t>(frame.instruction_pointer),
                GetModuleIndex(frame.module.Get(), modules)));
        }
    }

} // namespace

// Creates a sampler for the thread identified by |thread_id|. Returns null if
// the thread cannot be opened with the access rights needed for sampling, and
// always returns null when not building for 64-bit Windows.
std::unique_ptr<NativeStackSampler> NativeStackSampler::Create(
    PlatformThreadId thread_id,
    NativeStackSamplerTestDelegate* test_delegate)
{
#if _WIN64
    // Sampling needs to read the thread's context, suspend and resume it, and
    // query information about it.
    const DWORD desired_access = THREAD_GET_CONTEXT | THREAD_SUSPEND_RESUME | THREAD_QUERY_INFORMATION;
    HANDLE thread_handle = ::OpenThread(desired_access, FALSE, thread_id);
    if (!thread_handle)
        return std::unique_ptr<NativeStackSampler>();

    // The sampler takes ownership of the handle.
    return std::unique_ptr<NativeStackSampler>(new NativeStackSamplerWin(
        win::ScopedHandle(thread_handle), test_delegate));
#else
    return std::unique_ptr<NativeStackSampler>();
#endif
}

} // namespace base
